{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2Vn26lahhFJd"
   },
   "source": [
    "# EECS 498-007/598-005 Assignment 6-1: Variational AutoEncoders\n",
    "\n",
    "Before we start, please put your name and UMID in the following format\n",
    "\n",
    ": Firstname LASTNAME, #00000000   //   e.g.) Justin JOHNSON, #12345678"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZsZYgVWdALCH"
   },
   "source": [
    "**Your Answer:**   \n",
    "Hello WORLD, #XXXXXXXX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VJZ8AefthL95"
   },
   "source": [
    "\n",
    "# Variational Autoencoder\n",
    "\n",
    "In this notebook, you will implement a variational autoencoder and a conditional variational autoencoder with slightly different architectures and apply them to the popular MNIST handwritten digit dataset. Recall from lecture (https://web.eecs.umich.edu/~justincj/slides/eecs498/FA2020/598_FA2020_lecture19.pdf) that an autoencoder learns a latent representation of the training images from unlabeled data by learning to reconstruct its inputs. The *variational autoencoder* extends this model by adding a probabilistic spin to the encoder and decoder, allowing us to sample from the learned distribution of the latent space to generate new images at inference time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtA1_PsYhs24"
   },
   "source": [
    "## Setup Code\n",
    "Before getting started, we need to run some boilerplate code to set up our environment, as in previous assignments. You'll need to rerun this setup code each time you start the notebook.\n",
    "\n",
    "First, run this cell to load the autoreload extension. This allows us to edit .py source files and re-import them into the notebook for a seamless editing and debugging experience.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OKXSEQjRh63r"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eIXWSou6h_S6"
   },
   "source": [
    "### Google Colab Setup\n",
    "Next we need to run a few commands to set up our environment on Google Colab. If you are running this notebook on a local machine you can skip this section.\n",
    "\n",
    "Run the following cell to mount your Google Drive. Follow the link, sign in to your Google account (the same account you used to store this notebook!) and copy the authorization code into the text box that appears below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "evqEGDXRipC-"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2y_JbZTlDoai"
   },
   "source": [
    "Now recall the path in your Google Drive where you uploaded this notebook and fill it in below. If everything is working correctly, running the following cell should print the filenames from the assignment:\n",
    "\n",
    "```\n",
    "['eecs598', 'gan.py', 'generative_adversarial_networks.ipynb', 'a6_helper.py', 'vae.py', 'variational_autoencoders.ipynb']\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z8OU2pRkivyc"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# TODO: Fill in the Google Drive path where you uploaded the assignment\n",
    "# Example: If you create a 2020FA folder and put all the files under A6 folder, then '2020FA/A6'\n",
    "GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = None\n",
    "GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)\n",
    "print(os.listdir(GOOGLE_DRIVE_PATH))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GJ7auXOMi4rw"
   },
   "source": [
    "Once you have successfully mounted your Google Drive and located the path to \n",
    "this assignment, run the following cell to allow us to import from the `.py` files of this assignment. If it works correctly, it should print the message:\n",
    "\n",
    "```\n",
    "Hello from vae.py!\n",
    "Hello from a6_helper.py!\n",
    "```\n",
    "\n",
    "as well as the last edit time for the file `vae.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5zOOPYIUjCUO"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(GOOGLE_DRIVE_PATH)\n",
    "\n",
    "import time, os\n",
    "os.environ[\"TZ\"] = \"US/Eastern\"\n",
    "time.tzset()\n",
    "\n",
    "from vae import hello_vae\n",
    "hello_vae()\n",
    "\n",
    "from a6_helper import hello_helper\n",
    "hello_helper()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JuIiv2bhjFoC"
   },
   "source": [
    "Load several useful packages that are used in this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sLdT7GSljI0f"
   },
   "outputs": [],
   "source": [
    "from eecs598.grad import rel_error\n",
    "from eecs598.utils import reset_seed\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import init\n",
    "import torchvision\n",
    "import torchvision.transforms as T\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import sampler\n",
    "import torchvision.datasets as dset\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# for plotting\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots\n",
    "plt.rcParams['font.size'] = 16\n",
    "plt.rcParams['image.interpolation'] = 'nearest'\n",
    "plt.rcParams['image.cmap'] = 'gray'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_nqWhiLojS8M"
   },
   "source": [
    "We will use GPUs to accelerate our computation in this notebook. Run the following to make sure GPUs are enabled:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RdQhVgi5jVQp"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "  print('Good to go!')\n",
    "else:\n",
    "  print('Please set GPU via Edit -> Notebook Settings.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bcqRQILRjchz"
   },
   "source": [
    "## Load MNIST Dataset\n",
    "\n",
    "\n",
    "VAEs (and GANs, as you'll see in the next notebook) are notoriously finicky with hyperparameters and also require many training epochs. To make this assignment approachable, we will work with the MNIST dataset, which contains 60,000 training and 10,000 test images. Each picture contains a centered image of a white digit (0 through 9) on a black background. This was one of the first datasets used to train convolutional neural networks and it is fairly easy -- a standard CNN model can easily exceed 99% accuracy.\n",
    "\n",
    "To simplify our code here, we will use the PyTorch MNIST wrapper, which downloads and loads the MNIST dataset. See the [documentation](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py) for more information about the interface. The cell below loads only the training split (`train=True`); no separate validation set is created. The data will be saved into a folder called `MNIST_data`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mExnwvTXjcF_"
   },
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "\n",
    "mnist_train = dset.MNIST('./MNIST_data', train=True, download=True,\n",
    "                           transform=T.ToTensor())\n",
    "loader_train = DataLoader(mnist_train, batch_size=batch_size,\n",
    "                          shuffle=True, drop_last=True, num_workers=2)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CwDmYBjdhTrM"
   },
   "source": [
    "## Visualize dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q2X_21cTwsox"
   },
   "source": [
    "It is always a good idea to look at examples from the dataset before working with it. Let's visualize the digits in the MNIST dataset. We have defined the function `show_images` in `a6_helper.py` that we call to visualize the images.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3JMbbxMkwrYg"
   },
   "outputs": [],
   "source": [
    "from a6_helper import show_images\n",
    "\n",
    "imgs = next(iter(loader_train))[0].view(batch_size, 784)\n",
    "show_images(imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tOdQ3r5diEwr"
   },
   "source": [
    "# Fully Connected VAE\n",
    "\n",
    "Our first VAE implementation will consist solely of fully connected layers. We'll take the `1 x 28 x 28` shape of our input and flatten the features to create an input dimension size of 784. In this section you'll define the Encoder and Decoder models in the VAE class of `vae.py` and implement the reparametrization trick, forward pass, and loss function to train your first VAE."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aqMjCTSMi0jX"
   },
   "source": [
    "## FC-VAE Encoder\n",
    "\n",
    "Now let's start building our fully-connected VAE network. We'll start with the encoder, which will take our images as input (after flattening C,H,W to D shape) and pass them through three Linear+ReLU layers. We'll use this hidden representation to predict both the posterior mu and posterior log-variance using two separate linear layers (both outputs have shape (N,Z)). \n",
    "\n",
    "Note that we are calling this the 'logvar' layer because we'll use the log-variance (instead of variance or standard deviation) to stabilize training. This will specifically matter more when you compute reparametrization and the loss function later. \n",
    "\n",
    "*Define an `encoder`, `hidden_dim` (H), `mu_layer`, and `logvar_layer` in the initialization of the VAE class in `vae.py`. Use nn.Sequential to define the encoder, and separate Linear layers for the mu and logvar layers. In all of these layers, H will be a hidden dimension you set and will be the same across all encoder and decoder layers. Architecture for the encoder is described below:*\n",
    "\n",
    "\n",
    " * `Flatten` (Hint: nn.Flatten)\n",
    " * Fully connected layer with input size 784 (`input_size`) and output size H\n",
    " * `ReLU`\n",
    " * Fully connected layer with input size H and output size H\n",
    " * `ReLU`\n",
    " * Fully connected layer with input size H and output size H\n",
    " * `ReLU`\n"
   ]
  },
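  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shapes concrete, the encoder stage can be summarized as the mapping below. This is only a sketch: the symbols $f_{enc}$, $W_{\\mu}$, $b_{\\mu}$, $W_{\\ell}$ and $b_{\\ell}$ are names introduced here for illustration and do not appear in `vae.py`.\n",
    "\n",
    "$$h = f_{enc}(x) \\in \\mathbb{R}^{N \\times H}, \\qquad x \\in \\mathbb{R}^{N \\times 784}$$\n",
    "\n",
    "$$\\mu = h W_{\\mu} + b_{\\mu} \\in \\mathbb{R}^{N \\times Z}, \\qquad \\log \\sigma^2 = h W_{\\ell} + b_{\\ell} \\in \\mathbb{R}^{N \\times Z}$$\n",
    "\n",
    "Here $f_{enc}$ stands for the `Flatten` plus three Linear+ReLU layers listed above. The two linear heads share the same hidden features $h$ but have separate weights."
   ]
  },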
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gTuEAFrgkTyt"
   },
   "source": [
    "## FC-VAE Decoder\n",
    "\n",
    "We'll now define the decoder, which will take the latent space representation and generate a reconstructed image. The architecture is as follows: \n",
    "\n",
    " * Fully connected layer with input size as the latent size (Z) and output size H\n",
    " * `ReLU`\n",
    " * Fully connected layer with input size H and output size H\n",
    " * `ReLU`\n",
    " * Fully connected layer with input size H and output size H\n",
    " * `ReLU`\n",
    " * Fully connected layer with input size H and output size 784 (`input_size`)\n",
    " * `Sigmoid`\n",
    " * `Unflatten` (nn.Unflatten)\n",
    "\n",
    "*Define a `decoder` in the initialization of the VAE class in `vae.py`. Like the encoding step, use `nn.Sequential`*  \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aFTb-35TiOZl"
   },
   "source": [
    "## Reparametrization \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SdD23s3Vf70-"
   },
   "source": [
    "Now we'll apply the reparametrization trick to sample the latent vector $z$ from the approximate posterior during our forward pass, given the $\\mu$ and $\\sigma^2$ estimated by the encoder. A simple way to do this would be to sample directly from a normal distribution centered at $\\mu$ with variance $\\sigma^2$. However, we would then have to backpropagate through this random sampling step, which is not differentiable. Instead, we sample random noise $\\epsilon$ from a fixed distribution and compute $z$ as a function of ($\\epsilon$, $\\sigma^2$, $\\mu$). Specifically:\n",
    "\n",
    "$z = \\mu + \\sigma\\epsilon$\n",
    "\n",
    "We can easily find the partial derivatives w.r.t. $\\mu$ and $\\sigma^2$ and backpropagate through $z$. If $\\epsilon \\sim \\mathcal{N}(0,1)$, then it's easy to verify that the result of our forward pass calculation will follow a distribution centered at $\\mu$ with variance $\\sigma^2$.\n",
    "\n",
    "Implement `reparametrize` in `vae.py` and verify that your mean and std errors are at most `1e-4`."
   ]
  },
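  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the encoder predicts the log-variance rather than $\\sigma$ itself, the sampling step is usually written in terms of `logvar`. The following is a short worked sketch (not the reference solution), assuming $\\epsilon$ is drawn independently for each latent dimension:\n",
    "\n",
    "$$\\sigma = \\exp\\left(\\tfrac{1}{2}\\log\\sigma^2\\right), \\qquad z = \\mu + \\exp\\left(\\tfrac{1}{2}\\log\\sigma^2\\right)\\odot\\epsilon, \\qquad \\epsilon \\sim \\mathcal{N}(0, I)$$\n",
    "\n",
    "$$\\mathbb{E}[z] = \\mu + \\sigma\\,\\mathbb{E}[\\epsilon] = \\mu, \\qquad \\mathrm{Var}[z] = \\sigma^2\\,\\mathrm{Var}[\\epsilon] = \\sigma^2$$\n",
    "\n",
    "$$\\frac{\\partial z}{\\partial \\mu} = 1, \\qquad \\frac{\\partial z}{\\partial \\log\\sigma^2} = \\tfrac{1}{2}\\,\\sigma\\,\\epsilon$$\n",
    "\n",
    "For example, the check in the next cell uses $\\mu = 0$ and $\\log\\sigma^2 = 1$, so $\\sigma = e^{1/2} \\approx 1.6487$. The sample standard deviation of only 15 draws will be near this value but generally not equal to it."
   ]
  },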
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "T236XnbhbVH4"
   },
   "outputs": [],
   "source": [
    "reset_seed(0)\n",
    "from vae import reparametrize\n",
    "latent_size = 15\n",
    "size = (1, latent_size)\n",
    "mu = torch.zeros(size)\n",
    "logvar = torch.ones(size)\n",
    "\n",
    "z = reparametrize(mu, logvar)\n",
    "\n",
    "expected_mean = torch.FloatTensor([-0.4363])\n",
    "expected_std = torch.FloatTensor([1.6860])\n",
    "z_mean = torch.mean(z, dim=-1)\n",
    "z_std = torch.std(z, dim=-1)\n",
    "assert z.size() == size\n",
    "\n",
    "print('Mean Error', rel_error(z_mean, expected_mean))\n",
    "print('Std Error', rel_error(z_std, expected_std))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XOfC7oDrkUhl"
   },
   "source": [
    "## FC-VAE Forward\n",
    "\n",
    "Complete the VAE class by writing the forward pass. The forward pass should pass the input image through the encoder to calculate the estimation of mu and logvar, reparametrize to estimate the latent space z, and finally pass z into the decoder to generate an image.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IQ0UXWyMi9Gu"
   },
   "source": [
    "## Loss Function\n",
    "\n",
    "Before we're able to train our final model, we'll need to define our loss function. As seen below, the loss function for VAEs contains two terms: A reconstruction loss term (left) and KL divergence term (right). \n",
    "\n",
    "$-E_{z \\sim q_{\\phi}(z|x)}[\\log p_{\\theta}(x|z)] + D_{KL}(q_{\\phi}(z|x), p(z))$\n",
    "\n",
    "Note that this is the negative of the variational lower bound shown in lecture -- this ensures that when we minimize this loss, we maximize the variational lower bound. The reconstruction loss term can be computed by simply using the binary cross entropy loss between the original input pixels and the output pixels of our decoder (Hint: `nn.functional.binary_cross_entropy`). The KL divergence term pushes the latent space distribution to be close to a prior distribution (we're using a standard normal Gaussian as our prior).\n",
    "\n",
    "To help you out, we've derived an unvectorized form of the KL divergence term for you.\n",
    "Suppose that $q_\\phi(z|x)$ is a $Z$-dimensional diagonal Gaussian with mean $\\mu_{z|x}$ of shape $(Z,)$ and standard deviation $\\sigma_{z|x}$ of shape $(Z,)$, and that $p(z)$ is a $Z$-dimensional Gaussian with zero mean and unit variance. Then we can write the KL divergence term as:\n",
    "\n",
    "$D_{KL}(q_{\\phi}(z|x), p(z)) = -\\frac{1}{2} \\sum_{j=1}^{Z} \\left(1 + \\log(\\sigma_{z|x}^2)_{j} - (\\mu_{z|x})^2_{j} - (\\sigma_{z|x})^2_{j}\\right)$\n",
    "\n",
    "It's up to you to implement a vectorized version of this loss that also operates on minibatches.\n",
    "You should average the loss across samples in the minibatch.\n",
    "\n",
    "Implement `loss_function` in `vae.py` and verify your implementation below. Your relative error should be less than or equal to `1e-5`.\n",
    "\n"
   ]
  },
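  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to write the KL term directly in terms of the encoder outputs is sketched below. It assumes `mu` and `logvar` have shape $(N, Z)$ and uses $\\ell_{n,j} = \\log\\sigma^2_{n,j}$, a shorthand introduced only for this example:\n",
    "\n",
    "$$D_{KL} = -\\frac{1}{2N}\\sum_{n=1}^{N}\\sum_{j=1}^{Z}\\left(1 + \\ell_{n,j} - \\mu_{n,j}^2 - e^{\\ell_{n,j}}\\right)$$\n",
    "\n",
    "As a sanity check, take $N = 1$, $Z = 15$, $\\mu_{j} = 1$ and $\\ell_{j} = 0$ for every $j$, which are the values used in the test cell that follows:\n",
    "\n",
    "$$D_{KL} = -\\frac{1}{2}\\sum_{j=1}^{15}\\left(1 + 0 - 1 - 1\\right) = -\\frac{1}{2}\\cdot(-15) = 7.5$$\n",
    "\n",
    "The remaining $8.5079 - 7.5 \\approx 1.0079$ of the expected loss then comes from the reconstruction term."
   ]
  },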
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vF2ZUj2FjrFa"
   },
   "outputs": [],
   "source": [
    "from vae import loss_function\n",
    "size = (1,15)\n",
    "\n",
    "image = torch.sigmoid(torch.FloatTensor([[2,5], [6,7]]).unsqueeze(0).unsqueeze(0))\n",
    "image_hat = torch.sigmoid(torch.FloatTensor([[1,10], [9,3]]).unsqueeze(0).unsqueeze(0))\n",
    "\n",
    "expected_out = torch.tensor(8.5079)\n",
    "mu, logvar = torch.ones(size), torch.zeros(size)\n",
    "out = loss_function(image, image_hat, mu, logvar)\n",
    "print('Loss error', rel_error(expected_out,out))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wV8fbzenkAXm"
   },
   "source": [
    "\n",
    "## Train a model\n",
    "\n",
    "Now that we have our VAE defined and loss function ready, let's train our model! Our training script is provided in `a6_helper.py`, and we have pre-defined an Adam optimizer, learning rate, and number of epochs for you to use. \n",
    "\n",
    "Training for 10 epochs should take ~2 minutes and your loss should be less than 120."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rWaaacNHsfao"
   },
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "latent_size = 15\n",
    "from vae import VAE\n",
    "from a6_helper import train_vae\n",
    "input_size = 28*28\n",
    "device = 'cuda'\n",
    "vae_model = VAE(input_size, latent_size=latent_size)\n",
    "vae_model.cuda()\n",
    "for epoch in range(0, num_epochs):\n",
    "  train_vae(epoch, vae_model, loader_train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JT6Ek-26jjJD"
   },
   "source": [
    "## Visualize results\n",
    "\n",
    "After training our VAE network, we can take advantage of its power to generate new examples. This process involves only the decoder: we sample random latent vectors z from the prior and generate new examples by passing these latent vectors into the decoder. \n",
    "\n",
    "Run the cell below to generate new images! You should be able to visually recognize many of the digits, although some may be a bit blurry or badly formed. Our next model will see improvement in these results. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RhhrsgrMTyTi"
   },
   "outputs": [],
   "source": [
    "z = torch.randn(10, latent_size).to(device='cuda')\n",
    "import matplotlib.gridspec as gridspec\n",
    "vae_model.eval()\n",
    "samples = vae_model.decoder(z).data.cpu().numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(10, 1))\n",
    "gspec = gridspec.GridSpec(1, 10)\n",
    "gspec.update(wspace=0.05, hspace=0.05)\n",
    "for i, sample in enumerate(samples):\n",
    "  ax = plt.subplot(gspec[i])\n",
    "  plt.axis('off')\n",
    "  ax.set_xticklabels([])\n",
    "  ax.set_yticklabels([])\n",
    "  ax.set_aspect('equal')\n",
    "  plt.imshow(sample.reshape(28,28), cmap='Greys_r')\n",
    "# Save the completed figure once, after all subplots are drawn\n",
    "plt.savefig(os.path.join(GOOGLE_DRIVE_PATH,'vae_generation.jpg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sx3HGSpXk1MY"
   },
   "source": [
    "## Latent Space Interpolation\n",
    "\n",
    "As a final visual test of our trained VAE model, we can perform interpolation in latent space. We generate random latent vectors $z_0$ and $z_1$, and linearly interpolate between them; we run each interpolated vector through the trained decoder to produce an image.\n",
    "\n",
    "Each row of the figure below interpolates between two random vectors. For the most part the model should exhibit smooth transitions along each row, demonstrating that the model has learned something nontrivial about the underlying spatial structure of the digits it is modeling."
   ]
  },
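  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interpolation can be written out explicitly. As a sketch matching the next cell, with $S = 12$ steps per row and $w_i$ denoting the $i$-th interpolation weight (a symbol used only here):\n",
    "\n",
    "$$w_i = \\frac{i}{S - 1}, \\quad i = 0, \\dots, S - 1, \\qquad z^{(i)} = w_i\\, z_0 + (1 - w_i)\\, z_1$$\n",
    "\n",
    "So $z^{(0)} = z_1$ and $z^{(S-1)} = z_0$, and the intermediate vectors lie on the straight line segment between the two endpoints. With $S$ independent pairs $(z_0, z_1)$ this gives $S \\times S = 144$ latent vectors in total."
   ]
  },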
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XZ_4XsFURmN1"
   },
   "outputs": [],
   "source": [
    "S = 12\n",
    "latent_size = 15\n",
    "device = 'cuda'\n",
    "z0 = torch.randn(S, latent_size, device=device)\n",
    "z1 = torch.randn(S, latent_size, device=device)\n",
    "w = torch.linspace(0, 1, S, device=device).view(S, 1, 1)\n",
    "z = (w * z0 + (1 - w) * z1).transpose(0, 1).reshape(S * S, latent_size)\n",
    "x = vae_model.decoder(z)\n",
    "show_images(x.data.cpu())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wzS_ufzEkhah"
   },
   "source": [
    "# Conditional FC-VAE \n",
    "\n",
    "The second model you'll develop will be very similar to the FC-VAE, but with a slight conditional twist to it. We'll use what we know about the labels of each MNIST image, and *condition* our latent space and image generation on the specific class. Instead of $q_{\\phi} (z|x)$ and $p_{\\theta}(x|z)$ we have $q_{\\phi} (z|x,c)$ and $p_{\\theta}(x|z, c)$.\n",
    "\n",
    "This will allow us to do some powerful conditional generation at inference time. We can specifically choose to generate more 1s, 2s, 9s, etc. instead of simply generating new digits randomly."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hle0JuhwklKc"
   },
   "source": [
    "## Define Network with class input\n",
    "\n",
    "Our CVAE architecture will be the same as our FC-VAE architecture, except we'll now add a one-hot label vector to both the x input (in our case, the flattened image dimensions) and the z latent space. \n",
    "\n",
    "If our one-hot vector is called `c`, then `c[label] = 1` and every other entry of `c` is 0.\n",
    "\n",
    "For the `CVAE` class in `vae.py` use the same FC-VAE architecture implemented in the last network with the following modifications:\n",
    "\n",
    "1. Modify the first linear layer of your `encoder` to take in not only the flattened input image, but also the one-hot label vector `c`\n",
    "2. Modify the first layer of your `decoder` to project the latent space + one-hot vector to the `hidden_dim`\n",
    "3. Lastly, implement the `forward` pass to combine the flattened input image with the one-hot vectors (`torch.cat`) before passing them to the `encoder` and combine the latent space with the one-hot vectors (`torch.cat`) before passing them to the `decoder`"
   ]
  },
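  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the modified layer sizes concrete, here is a small worked example. It assumes 10 classes (the MNIST digits), the latent size $Z = 15$ used elsewhere in this notebook, and a label of 3 chosen purely for illustration:\n",
    "\n",
    "$$c = (0, 0, 0, 1, 0, 0, 0, 0, 0, 0) \\in \\mathbb{R}^{10}$$\n",
    "\n",
    "$$[x; c] \\in \\mathbb{R}^{784 + 10} = \\mathbb{R}^{794}, \\qquad [z; c] \\in \\mathbb{R}^{15 + 10} = \\mathbb{R}^{25}$$\n",
    "\n",
    "Under these assumptions, the first encoder layer would take 794 input features and the first decoder layer 25. For a minibatch of $N$ images the concatenated tensors have shapes $(N, 794)$ and $(N, 25)$."
   ]
  },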
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bUzKyFI9kp8i"
   },
   "source": [
    "## Train model\n",
    "\n",
    "Using the same training script, let's now train our CVAE! \n",
    "\n",
    "Training for 10 epochs should take ~2 minutes and your loss should be less than 120."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N1dzKDUsunbD"
   },
   "outputs": [],
   "source": [
    "from vae import CVAE\n",
    "num_epochs = 10\n",
    "latent_size = 15\n",
    "from a6_helper import train_vae\n",
    "input_size = 28*28\n",
    "device = 'cuda'\n",
    "\n",
    "cvae = CVAE(input_size, latent_size=latent_size)\n",
    "cvae.cuda()\n",
    "for epoch in range(0, num_epochs):\n",
    "  train_vae(epoch, cvae, loader_train, cond=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GMAyFBZTkr1Y"
   },
   "source": [
    "## Visualize Results\n",
    "\n",
    "We've trained our CVAE, now let's conditionally generate some new data! This time, we can specify the class we want to generate by appending a one-hot matrix of class labels to the latent vectors. We use `torch.eye` to create an identity matrix, which effectively gives us one label for each digit. When you run the cell below, you should get one example per digit. Each digit should be reasonably distinguishable (it is ok to run this cell a few times to save your best results).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GCfwpz0NALdZ"
   },
   "outputs": [],
   "source": [
    "z = torch.randn(10, latent_size)\n",
    "c = torch.eye(10, 10) # [one hot labels for 0-9]\n",
    "import matplotlib.gridspec as gridspec\n",
    "z = torch.cat((z,c), dim=-1).to(device='cuda')\n",
    "cvae.eval()\n",
    "samples = cvae.decoder(z).data.cpu().numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(10, 1))\n",
    "gspec = gridspec.GridSpec(1, 10)\n",
    "gspec.update(wspace=0.05, hspace=0.05)\n",
    "for i, sample in enumerate(samples):\n",
    "  ax = plt.subplot(gspec[i])\n",
    "  plt.axis('off')\n",
    "  ax.set_xticklabels([])\n",
    "  ax.set_yticklabels([])\n",
    "  ax.set_aspect('equal')\n",
    "  plt.imshow(sample.reshape(28, 28), cmap='Greys_r')\n",
    "# Save the completed figure once, after all subplots are drawn\n",
    "plt.savefig(os.path.join(GOOGLE_DRIVE_PATH,'conditional_vae_generation.jpg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vW8GmNSwY5Jx"
   },
   "source": [
    "## Final Check\n",
    "\n",
    "Make sure all your training results (loss + images) are saved in the notebook. You can run \"Runtime -> Restart and run all...\" to double check before submitting."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "variational_autoencoders.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
